import os

# Make CUDA kernel launches synchronous so device errors are reported at the
# offending call. Set before the project imports below, which presumably pull
# in CUDA-using code (TODO confirm).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys


def resolve_project_root(cwd):
    """Return the project root for the working directory *cwd*.

    If the last path component of *cwd* is ``interface``, its parent
    directory is returned; otherwise *cwd* is returned unchanged. Only that
    trailing component is stripped, so directories whose names merely contain
    ``interface`` (e.g. ``interface_tools``) are left intact.
    """
    if os.path.basename(cwd) == 'interface':
        return os.path.dirname(cwd)
    return cwd


# Make the project's top-level packages (player_ranking, generic, ...)
# importable when this script is launched from the ``interface`` directory.
cwd = os.getcwd()
sys.path.append(resolve_project_root(cwd))
print(sys.path)

from player_ranking.player_evaluation_metric import run_comparison_player_evaluation
from agent import SportsAgent
from generic.data_util import load_config, read_args


def test(args, rank_metric='SI'):
    """Run the player-ranking comparison evaluation in ``'test'`` mode.

    Args:
        args: Parsed command-line arguments, passed unchanged to
            ``load_config``.
        rank_metric: Ranking metric name forwarded to
            ``run_comparison_player_evaluation``. Defaults to ``'SI'``, the
            value this function previously hard-coded.

    If ``load_config`` returns a log file path, that file is opened for
    writing, shared by the ``SportsAgent`` and the evaluation, and closed
    when evaluation finishes or raises.
    """
    mode = 'test'
    config, debug_mode, log_file_path = load_config(args)
    # Prefix forwarded to the evaluation; presumably tags debug-run outputs.
    debug_msg = 'debug_' if debug_mode else ''

    # Disabled alternatives and per-metric config overrides, retained for
    # reference only (none of this currently runs):
    # rank_metric = config['general']['task']
    # if rank_metric == 'PM':
    #     pass
    # elif rank_metric == 'GIM':
    #     test_train_rate = 0.8
    #     test_gamma = 1
    #     max_trace_length = 3
    #     apply_dynamic_trace_length = False
    #     test_apply_rnn = True
    #     test_apply_resnet = False
    #     test_cut_at_goal = True
    #     date = 'Oct-29-2021'
    #     iteration = 'test'
    #
    #     config['general']['model']['apply_rnn'] = test_apply_rnn
    #     config['general']['model']['apply_resnet'] = test_apply_resnet
    #     config['general']['training']['gamma'] = test_gamma
    #     config['general']['training']['cut_at_goal'] = test_cut_at_goal
    #     config['general']['training']['train_rate'] = test_train_rate
    #     config['general']['model']['apply_dynamic_trace_length'] = apply_dynamic_trace_length
    #     config['general']['model']['max_trace_length'] = max_trace_length
    #     config['general']['use_cuda'] = False

    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        agent = SportsAgent(config=config, log_file=log_file)

        run_comparison_player_evaluation(agent=agent,
                                         rank_metric=rank_metric,
                                         model_save_path=None,
                                         log_file=log_file,
                                         mode=mode,
                                         sanity_check_msg='',
                                         debug_msg=debug_msg,
                                         debug_mode=debug_mode)
    finally:
        # Previously the handle was never closed; close (and flush) it even
        # if agent construction or the evaluation raises.
        if log_file is not None:
            log_file.close()



if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    test(read_args())
